import jsonlines
import torch
import sys
from tqdm import tqdm
import math


def kl_divergence(p_probs, q_probs):
    """Return KL(p || q) for two probability lists, or 0.0 if it is not usable.

    Both inputs are truncated to their common length before comparison.
    Terms where p == 0 contribute 0 (the 0 * log 0 = 0 convention, via
    ``torch.xlogy``) instead of turning the whole sum into NaN.

    Args:
        p_probs: sequence of probabilities (the distribution being scored).
        q_probs: sequence of reference probabilities.

    Returns:
        The KL divergence as a Python float. Non-finite results (e.g. q == 0
        where p > 0) and negative results (numerical noise) are mapped to 0.0,
        matching the original script's fallback.
    """
    p = torch.tensor(p_probs, dtype=torch.float32)
    q = torch.tensor(q_probs, dtype=torch.float32)
    # Align lengths in both directions; the original only truncated p, which
    # crashed with a shape mismatch whenever p was the shorter one.
    common = min(len(p), len(q))
    p, q = p[:common], q[:common]
    # sum p*log(p) - p*log(q); xlogy returns 0 where p == 0.
    kl = torch.sum(torch.xlogy(p, p) - torch.xlogy(p, q), dim=-1).item()
    if not math.isfinite(kl) or kl < 0.0:
        return 0.0
    return kl


def add_kl_scores(record, reference_key="vanilla"):
    """Replace every entry of ``record['K']`` with its KL divergence vs the reference.

    The reference distribution is read once before the loop. Previously it was
    re-read each iteration after the reference entry itself had already been
    overwritten by a float.

    Args:
        record: dict with a ``'K'`` mapping of name -> probability list.
        reference_key: key in ``record['K']`` holding the reference distribution.

    Returns:
        The same ``record`` object, mutated in place.
    """
    scores = record['K']
    reference = scores[reference_key]
    for key in scores:
        # Reassigning existing keys during iteration is safe (size unchanged).
        scores[key] = kl_divergence(scores[key], reference)
    return record


def main(argv=None):
    """Read a ``*probs*`` JSONL file and write per-key KL scores to the ``*kl*`` path.

    Args:
        argv: command-line arguments excluding the program name; defaults to
            ``sys.argv[1:]``. Exactly one argument, the input path, is expected.

    Raises:
        SystemExit: on wrong usage, or when the derived output path would
            overwrite the input file.
    """
    argv = sys.argv[1:] if argv is None else argv
    if len(argv) != 1:
        raise SystemExit(f"usage: {sys.argv[0]} <path-containing-probs.jsonl>")
    path = argv[0]
    out_path = path.replace('probs', 'kl')
    # Without a 'probs' substring the output path equals the input path and the
    # input would be silently clobbered with KL values.
    if out_path == path:
        raise SystemExit(f"refusing to overwrite input: {path!r} contains no 'probs'")

    with jsonlines.open(path, "r") as reader:
        data = list(reader)

    for record in tqdm(data):
        add_kl_scores(record)

    # Context manager guarantees the writer is flushed and closed.
    with jsonlines.open(out_path, 'w') as writer:
        writer.write_all(data)


if __name__ == "__main__":
    main()








